#include "param.h"
#include "types.h"
#include "memlayout.h"
#include "elf.h"
#include "riscv.h"
#include "defs.h"
#include "fs.h"

/*
 * the kernel's page table.
 * The global kernel page table: created and filled by kvminit(),
 * loaded into satp on each hart by kvminithart().
 */
pagetable_t kernel_pagetable;

extern char etext[];  // kernel.ld sets this to end of kernel code.

extern char trampoline[];  // trampoline.S

/*
 * create a direct-map page table for the kernel.
 */
void kvminit() {
  kernel_pagetable = (pagetable_t)kalloc();
  memset(kernel_pagetable, 0, PGSIZE);

  // uart registers
  kvmmap(UART0, UART0, PGSIZE, PTE_R | PTE_W);

  // virtio mmio disk interface
  kvmmap(VIRTIO0, VIRTIO0, PGSIZE, PTE_R | PTE_W);

  // CLINT
  kvmmap(CLINT, CLINT, 0x10000, PTE_R | PTE_W);

  // PLIC
  kvmmap(PLIC, PLIC, 0x400000, PTE_R | PTE_W);

  // map kernel text executable and read-only.
  kvmmap(KERNBASE, KERNBASE, (uint64)etext - KERNBASE, PTE_R | PTE_X);

  // map kernel data and the physical RAM we'll make use of.
  kvmmap((uint64)etext, (uint64)etext, PHYSTOP - (uint64)etext, PTE_R | PTE_W);

  // map the trampoline for trap entry/exit to
  // the highest virtual address in the kernel.
  kvmmap(TRAMPOLINE, (uint64)trampoline, PGSIZE, PTE_R | PTE_X);
}

// Map [va, va+sz) onto [pa, pa+sz) in k_pagetable with permission bits perm.
// Panics if mappages() fails (it returns nonzero when walk() cannot allocate
// a page-table page).
void kvmmap_new(pagetable_t k_pagetable, uint64 va, uint64 pa, uint64 sz, int perm) {
  int rc = mappages(k_pagetable, va, sz, pa, perm);
  if (rc != 0) {
    panic("kvmmap_new");
  }
}

/*
 * Create a fresh per-process kernel page table and return it.
 *
 * Installs the same direct (va == pa) mappings as kvminit():
 *   - UART0, VIRTIO0, PLIC (R|W),
 *   - kernel text (R|X),
 *   - kernel data and RAM up to PHYSTOP (R|W),
 * plus the trampoline page at TRAMPOLINE.
 *
 * NOTE(review): CLINT is not mapped here, unlike kvminit(). This is presumably
 * because sync_pagetable() overwrites the level-1 slots below PLIC with user
 * mappings, and CLINT lies below PLIC -- confirm against memlayout.h.
 *
 * Panics if the root page or any intermediate page-table page cannot be
 * allocated (kvmmap_new() panics whenever mappages() fails).
 */
pagetable_t vmcreate(){
  pagetable_t k_pagetable = (pagetable_t)kalloc();
  // Without this check a failed kalloc() would be passed to memset() as 0.
  if (k_pagetable == 0) panic("vmcreate: kalloc");

  memset(k_pagetable, 0, PGSIZE);

  // Device registers.
  kvmmap_new(k_pagetable, UART0, UART0, PGSIZE, PTE_R | PTE_W);

  kvmmap_new(k_pagetable, VIRTIO0, VIRTIO0, PGSIZE, PTE_R | PTE_W);

  kvmmap_new(k_pagetable, PLIC, PLIC, 0x400000, PTE_R | PTE_W);

  // Kernel text: executable, read-only.
  kvmmap_new(k_pagetable, KERNBASE, KERNBASE, (uint64)etext - KERNBASE, PTE_R | PTE_X);

  // Kernel data and the rest of physical RAM.
  kvmmap_new(k_pagetable, (uint64)etext, (uint64)etext, PHYSTOP - (uint64)etext, PTE_R | PTE_W);

  // Trampoline for trap entry/exit, at the top of the address space.
  kvmmap_new(k_pagetable, TRAMPOLINE, (uint64)trampoline, PGSIZE, PTE_R | PTE_X);

  return k_pagetable;
}


// Switch h/w page table register to the kernel's page table,
// and enable paging.
void kvminithart() {
  // 使用 w_satp 指令将内核的页表地址加载到 satp 寄存器中，切换到内核的虚拟地址空间。
  w_satp(MAKE_SATP(kernel_pagetable));
  // 使用 sfence_vma() 指令刷新 TLB，确保所有的虚拟内存地址到物理地址的映射都使用新的页表。
  sfence_vma();
}

// Return the address of the level-0 PTE in pagetable that corresponds to
// virtual address va, or 0 if an intermediate table is missing and either
// alloc == 0 or kalloc() fails. When alloc != 0, missing intermediate
// page-table pages are allocated, zeroed, and linked in. Panics if
// va >= MAXVA.
//
// The risc-v Sv39 scheme has three levels of page-table
// pages. A page-table page contains 512 64-bit PTEs.
// A 64-bit virtual address is split into five fields:
//   39..63 -- must be zero.
//   30..38 -- 9 bits of level-2 index.
//   21..29 -- 9 bits of level-1 index.
//   12..20 -- 9 bits of level-0 index.
//    0..11 -- 12 bits of byte offset within the page.
pte_t *walk(pagetable_t pagetable, uint64 va, int alloc) {
  if (va >= MAXVA) {
    panic("walk");
  }

  pagetable_t table = pagetable;
  int level = 2;
  while (level != 0) {
    pte_t *entry = &table[PX(level, va)];
    if ((*entry & PTE_V) == 0) {
      // No next-level table yet: create one only if the caller allows it.
      if (!alloc) {
        return 0;
      }
      table = (pagetable_t)kalloc();
      if (table == 0) {
        return 0;
      }
      memset(table, 0, PGSIZE);
      *entry = PA2PTE(table) | PTE_V;
    } else {
      table = (pagetable_t)PTE2PA(*entry);
    }
    level--;
  }
  return &table[PX(0, va)];
}

// Translate user virtual address va to the physical address of its page.
// Returns 0 if va >= MAXVA, if no valid mapping exists, or if the mapping
// lacks PTE_U. Only usable for user pages.
uint64 walkaddr(pagetable_t pagetable, uint64 va) {
  if (va >= MAXVA) {
    return 0;
  }

  pte_t *entry = walk(pagetable, va, 0);
  // Reject missing, invalid, or kernel-only mappings.
  if (entry == 0 || !(*entry & PTE_V) || !(*entry & PTE_U)) {
    return 0;
  }
  return PTE2PA(*entry);
}

// Add a mapping of [va, va+sz) onto [pa, pa+sz), with permission bits perm,
// to the global kernel page table. Only used when booting; it neither flushes
// the TLB nor enables paging. Panics if mappages() fails.
void kvmmap(uint64 va, uint64 pa, uint64 sz, int perm) {
  int rc = mappages(kernel_pagetable, va, sz, pa, perm);
  if (rc != 0) {
    panic("kvmmap");
  }
}

// Translate kernel virtual address va to a physical address using the global
// kernel page table. The byte offset of va within its page is preserved, so
// va need not be page-aligned. (Originally documented as only needed for
// addresses on the stack.) Panics if va has no valid mapping.
uint64 kvmpa(uint64 va) {
  pte_t *entry = walk(kernel_pagetable, va, 0);
  if (entry == 0 || (*entry & PTE_V) == 0) {
    panic("kvmpa");
  }
  return PTE2PA(*entry) + va % PGSIZE;
}

// Create PTEs for virtual addresses starting at va that refer to
// physical addresses starting at pa. va and size need not be
// page-aligned: every page overlapping [va, va+size) gets mapped.
//
// Returns 0 on success, -1 if walk() couldn't allocate a needed
// page-table page. Mappings created before such a failure are left in place.
// Panics if size is 0 or if a page in the range is already mapped.
int mappages(pagetable_t pagetable, uint64 va, uint64 size, uint64 pa, int perm) {
  uint64 a, last;
  pte_t *pte;

  // With size == 0 and a page-aligned va, va + size - 1 rounds down to the
  // page *before* va. The loop below would then never see a == last and
  // would keep mapping pages until walk() panics at MAXVA (or it hits
  // "remap" or runs out of memory). Reject the request up front.
  if (size == 0) panic("mappages: size");

  a = PGROUNDDOWN(va);
  last = PGROUNDDOWN(va + size - 1);
  for (;;) {
    if ((pte = walk(pagetable, a, 1)) == 0) return -1;
    if (*pte & PTE_V) panic("remap");
    *pte = PA2PTE(pa) | perm | PTE_V;
    if (a == last) break;
    a += PGSIZE;
    pa += PGSIZE;
  }
  return 0;
}

// Remove npages consecutive page mappings starting at va, which must be
// page-aligned. Every page in the range must be mapped by a leaf PTE,
// otherwise this panics. If do_free is nonzero, the physical page behind
// each mapping is also returned to kfree().
void uvmunmap(pagetable_t pagetable, uint64 va, uint64 npages, int do_free) {
  if ((va % PGSIZE) != 0) panic("uvmunmap: not aligned");

  uint64 end = va + npages * PGSIZE;
  for (uint64 cur = va; cur < end; cur += PGSIZE) {
    pte_t *entry = walk(pagetable, cur, 0);
    if (entry == 0) panic("uvmunmap: walk");
    if ((*entry & PTE_V) == 0) panic("uvmunmap: not mapped");
    // A valid PTE carrying no other flag bits points at a page-table page.
    if (PTE_FLAGS(*entry) == PTE_V) panic("uvmunmap: not a leaf");

    if (do_free) {
      kfree((void *)PTE2PA(*entry));
    }
    *entry = 0;
  }
}

// create an empty user page table.
// returns 0 if out of memory.
pagetable_t uvmcreate() {
  pagetable_t pagetable;
  pagetable = (pagetable_t)kalloc();
  if (pagetable == 0) return 0;
  memset(pagetable, 0, PGSIZE);
  return pagetable;
}

// Load the user initcode into address 0 of pagetable,
// for the very first process.
// A single zeroed page is allocated, mapped at va 0 with R/W/X/U permission,
// and the first sz bytes of src are copied into it.
// sz must be less than a page. Panics if sz is too large, if no page can be
// allocated, or if the mapping cannot be installed.
void uvminit(pagetable_t pagetable, uchar *src, uint sz) {
  char *mem;

  if (sz >= PGSIZE) panic("inituvm: more than a page");
  mem = kalloc();
  // Fail loudly instead of letting memset() write through a null pointer.
  if (mem == 0) panic("inituvm: kalloc");
  memset(mem, 0, PGSIZE);
  // Without a mapping the first process could never run; don't ignore failure.
  if (mappages(pagetable, 0, PGSIZE, (uint64)mem, PTE_W | PTE_R | PTE_X | PTE_U) != 0)
    panic("inituvm: mappages");
  memmove(mem, src, sz);
}

// Grow a user address space from oldsz to newsz bytes (neither needs to be
// page-aligned). Each new page is zeroed and mapped R/W/X/U.
// Returns newsz on success. Returns oldsz unchanged if newsz < oldsz.
// Returns 0 if allocation or mapping fails, after releasing any pages added
// by this call.
uint64 uvmalloc(pagetable_t pagetable, uint64 oldsz, uint64 newsz) {
  if (newsz < oldsz) {
    return oldsz;
  }

  uint64 start = PGROUNDUP(oldsz);
  for (uint64 va = start; va < newsz; va += PGSIZE) {
    char *page = kalloc();
    if (page == 0) {
      // Roll back the pages added so far: [start, va).
      uvmdealloc(pagetable, va, start);
      return 0;
    }
    memset(page, 0, PGSIZE);

    int rc = mappages(pagetable, va, PGSIZE, (uint64)page, PTE_W | PTE_X | PTE_R | PTE_U);
    if (rc != 0) {
      kfree(page);
      uvmdealloc(pagetable, va, start);
      return 0;
    }
  }
  return newsz;
}

// Shrink a user address space from oldsz to newsz bytes, unmapping and freeing
// whole pages that lie entirely beyond PGROUNDUP(newsz). Neither size needs to
// be page-aligned, and oldsz may exceed the real process size.
// Returns newsz, or oldsz unchanged when newsz >= oldsz.
uint64 uvmdealloc(pagetable_t pagetable, uint64 oldsz, uint64 newsz) {
  if (newsz >= oldsz) return oldsz;

  uint64 keep_end = PGROUNDUP(newsz);
  uint64 old_end = PGROUNDUP(oldsz);
  if (keep_end < old_end) {
    int npages = (old_end - keep_end) / PGSIZE;
    uvmunmap(pagetable, keep_end, npages, 1);
  }
  return newsz;
}

// Recursively free a page table and every lower-level page-table page it
// references. All leaf mappings must already have been removed (e.g. with
// uvmunmap()); a remaining leaf PTE causes a panic.
void freewalk(pagetable_t pagetable) {
  // A page-table page holds 2^9 = 512 PTEs.
  for (int idx = 0; idx < 512; idx++) {
    pte_t entry = pagetable[idx];
    if ((entry & PTE_V) == 0) {
      continue;  // empty slot
    }
    if (entry & (PTE_R | PTE_W | PTE_X)) {
      // A valid leaf still maps memory: the caller forgot to unmap it.
      panic("freewalk: leaf");
    }
    // Valid with no R/W/X bits: points to a lower-level page table.
    freewalk((pagetable_t)PTE2PA(entry));
    pagetable[idx] = 0;
  }
  kfree((void *)pagetable);
}

// 页表打印功能 
void vmprint(pagetable_t pgtbl, int depth, uint64 va_last){
  // there are 2^9 = 512 PTEs in a page table.
  // 打印 vmprint 的参数，即获得的页表参数具体的值
  // printf("page table 0x%p\n",(void*)pgtbl);
  // 遍历一个页表页的PTE表项 
  int flag = 0;
  for (int i = 0; i < 512; i++) {
    // printf("||");
    pte_t pte = pgtbl[i];
    if (pte & PTE_V){
      
      if (depth==0) {
        if (flag ==0) {
          flag = 1;
          printf("page table %p\n",(void*)(pgtbl));
        }
        printf("||");
      }
      else if (depth==1) printf("||   ||");
      else if (depth==2) printf("||   ||   ||");
    }
    /* 判断PTE的Flag位，如果还有下一级页表(即当前是根页表或次页表)，
       则递归调用freewalk释放页表项，并将对应的PTE清零 */
    if ((pte & PTE_V) && (pte & (PTE_R | PTE_W | PTE_X)) == 0) {
      uint64 va = va_last + i;
      va = va << 9;
      // 非叶子结点
      char flags[5] = "rwxu";  // 定义一个字符数组，注意要给空间留出位置（包括 '\0'）
      if (!(pte & PTE_R)) flags[0] = '-';
      if (!(pte & PTE_W)) flags[1] = '-';
      if (!(pte & PTE_X)) flags[2] = '-';
      if (!(pte & PTE_U)) flags[3] = '-';
      // uint64 va = (i << 12);  // 计算虚拟地址的索引部分，假设是 12 位偏移
      uint64 child = PTE2PA(pte); // 将PTE转为物理地址
      // printf("idx: : pa: , flags: %s\n", flags);  // 打印字符串
      printf("idx: %d: pa: %p, flags: %s\n", i, (void*)child, flags);  // 打印字符串
      vmprint((pagetable_t)child,depth+1, va);
    } else if (pte & PTE_V) {
      // 叶子结点
      // 计算虚拟地址
      uint64 va = va_last + i;
      va = va << 12;

      char flags[5] = "rwxu";  // 定义一个字符数组，注意要给空间留出位置（包括 '\0'）
      if (!(pte & PTE_R)) flags[0] = '-';
      if (!(pte & PTE_W)) flags[1] = '-';
      if (!(pte & PTE_X)) flags[2] = '-';
      if (!(pte & PTE_U)) flags[3] = '-';

      uint64 child = PTE2PA(pte); // 将PTE转为物理地址
      printf("idx: %d: va: %p -> pa: %p, flags: %s\n", i, (void*)va, (void*)child, flags);  // 打印字符串

    }
  }
}

// Release a user address space: unmap and free the pages covering [0, sz)
// (skipped when sz is 0), then free every page-table page.
void uvmfree(pagetable_t pagetable, uint64 sz) {
  if (sz != 0) {
    uvmunmap(pagetable, 0, PGROUNDUP(sz) / PGSIZE, 1);
  }
  freewalk(pagetable);
}

// Duplicate the first sz bytes of the parent's address space (old) into the
// child's page table (new). Each page gets a fresh physical copy mapped with
// the same PTE flags.
// Returns 0 on success. Returns -1 on allocation or mapping failure, after
// unmapping and freeing the pages already copied into new.
// Panics if any page of the parent in [0, sz) is missing.
int uvmcopy(pagetable_t old, pagetable_t new, uint64 sz) {
  uint64 va = 0;

  for (; va < sz; va += PGSIZE) {
    pte_t *src = walk(old, va, 0);
    if (src == 0) panic("uvmcopy: pte should exist");
    if ((*src & PTE_V) == 0) panic("uvmcopy: page not present");

    uint64 src_pa = PTE2PA(*src);
    uint perm = PTE_FLAGS(*src);

    char *copy = kalloc();
    if (copy == 0) goto fail;
    memmove(copy, (char *)src_pa, PGSIZE);

    if (mappages(new, va, PGSIZE, (uint64)copy, perm) != 0) {
      kfree(copy);
      goto fail;
    }
  }
  return 0;

fail:
  // Undo the pages already copied into new: [0, va).
  uvmunmap(new, 0, va / PGSIZE, 1);
  return -1;
}

// Strip PTE_U from the PTE for va so user mode can no longer access that page
// (exec uses this for the user stack guard page). Panics if walk() finds no
// level-0 PTE slot for va.
void uvmclear(pagetable_t pagetable, uint64 va) {
  pte_t *entry = walk(pagetable, va, 0);
  if (entry == 0) {
    panic("uvmclear");
  }
  *entry &= ~PTE_U;
}

// Copy len bytes from kernel buffer src to user virtual address dstva in the
// given page table, one page-sized chunk at a time.
// Returns 0 on success, or -1 if any destination page is not a valid user page.
int copyout(pagetable_t pagetable, uint64 dstva, char *src, uint64 len) {
  while (len != 0) {
    uint64 page_va = PGROUNDDOWN(dstva);
    uint64 page_pa = walkaddr(pagetable, page_va);
    if (page_pa == 0) {
      return -1;
    }

    uint64 offset = dstva - page_va;
    uint64 chunk = PGSIZE - offset;  // bytes left in this page
    if (chunk > len) {
      chunk = len;
    }
    memmove((void *)(page_pa + offset), src, chunk);

    src += chunk;
    len -= chunk;
    dstva = page_va + PGSIZE;
  }
  return 0;
}

// Copy from user to kernel.
// Copy len bytes to dst from user virtual address srcva.
// Return 0 on success, -1 on error (as reported by copyin_new()).
//
// The copy itself is delegated to copyin_new(). Supervisor-mode accesses to
// PTE_U pages require sstatus.SUM, so SUM is set around the call. Afterwards
// it is put back to the value it had on entry rather than unconditionally
// cleared, so a caller that already had SUM set is not disturbed.
int copyin(pagetable_t pagetable, char *dst, uint64 srcva, uint64 len) {
  uint64 prev_sum = r_sstatus() & SSTATUS_SUM;
  w_sstatus(r_sstatus() | SSTATUS_SUM);
  int ret = copyin_new(pagetable, dst, srcva, len);
  w_sstatus((r_sstatus() & ~SSTATUS_SUM) | prev_sum);
  return ret;
}

// Copy a null-terminated string from user to kernel.
// Copy bytes to dst from user virtual address srcva,
// until a '\0', or max.
// Return 0 on success, -1 on error (as reported by copyinstr_new()).
//
// The copy itself is delegated to copyinstr_new(). sstatus.SUM is set around
// the call so supervisor mode may access PTE_U pages. It is then restored to
// the value it had on entry rather than unconditionally cleared.
int copyinstr(pagetable_t pagetable, char *dst, uint64 srcva, uint64 max) {
  uint64 prev_sum = r_sstatus() & SSTATUS_SUM;
  w_sstatus(r_sstatus() | SSTATUS_SUM);
  int ret = copyinstr_new(pagetable, dst, srcva, max);
  w_sstatus((r_sstatus() & ~SSTATUS_SUM) | prev_sum);
  return ret;
}

// Debug helper: report whether this hart's satp differs from the global kernel
// page table (i.e. a different page table is active). Prints the result and
// returns 1 if it differs, 0 otherwise.
int test_pagetable() {
  uint64 current = r_satp();
  int differs = (current != MAKE_SATP(kernel_pagetable));
  printf("test_pagetable: %d\n", differs);
  return differs;
}

// Make the low part of a per-process kernel page table mirror a user page
// table, so the kernel can dereference user virtual addresses directly.
//
// Both tables' level-2 entry 0 (the first 1 GiB of VA) are followed to their
// level-1 tables. Level-1 entries [0, PX(1, PLIC)) are then copied from the
// user table into the kernel table; each entry covers 2 MiB, so together they
// span VA [0, PLIC). The copied entries point at the user's own level-0
// tables, so leaf PTEs keep PTE_U (hence copyin/copyinstr set sstatus.SUM).
//
// If the user table has no level-1 table yet, the mirrored range is cleared.
// Panics if the kernel table has none. NOTE(review): tables from vmcreate()
// presumably always have one, since UART0/VIRTIO0/PLIC live below 1 GiB.
// NOTE(review): assumes user memory never reaches PLIC (enforced by callers)
// and that the caller issues sfence_vma() if kpagetable is active -- confirm.
void sync_pagetable(pagetable_t user_pagetable, pagetable_t kpagetable){
  // kpagetable was previously named kernel_pagetable, shadowing the global.
  pte_t kpte = kpagetable[0];
  if ((kpte & PTE_V) == 0) panic("sync_pagetable: no kernel level-1 table");
  pagetable_t kernel_l1 = (pagetable_t)PTE2PA(kpte);

  // PTE2PA() of an invalid entry would be physical address 0; don't use it.
  pte_t upte = user_pagetable[0];
  pagetable_t user_l1 = (upte & PTE_V) ? (pagetable_t)PTE2PA(upte) : 0;

  // PX(1, PLIC) is the first level-1 slot used by the PLIC mapping
  // (0x60 for PLIC == 0x0c000000); everything below it mirrors user memory.
  for (uint64 i = 0; i < PX(1, PLIC); i++) {
    kernel_l1[i] = user_l1 ? user_l1[i] : 0;
  }
}
